#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include <vector>

template <typename scalar_t>
__device__ __forceinline__ scalar_t logSumExp(scalar_t a, scalar_t b) {
  // standard log-sum-exp trick is used here to provide better numerical stability
  return (a >= b) ? a + std::log1p(exp(b - a)) : b + std::log1p(exp(a - b));
}
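// Equivalently, logSumExp(a, b) = log(exp(a) + exp(b)) = max(a, b) + log1p(exp(-|a - b|)),
// which never exponentiates a positive argument and thus avoids overflow.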

// Vanilla transducer loss function (i.e., the forward-backward algorithm)
// Details of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.

// The forward (alpha) and backward (beta) passes are launched together. The input is assumed to
// have been converted to log scale by the preceding log_softmax layer.
// The diagonal-wavefront advance commonly used in dynamic programming is leveraged here.
// alpha and beta are of acc_t type, as they are essentially accumulators.

// This loss function supports packed input, where a tensor of shape [B, T, U, H] is packed into
// [B_packed, H].
// The don't-care region (t > audLen or u > txtLen) is removed.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
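// As an illustrative (hypothetical) example: with B = 2, audLen = {3, 4} and txtLen = {2, 1},
// the packed input holds 3*(2+1) + 4*(1+1) = 17 rows of size H, and batchOffset is expected to
// be the inclusive prefix sum of the per-sequence element counts, i.e. {9, 17}, so that batch 1
// starts at row batchOffset[0] = 9 (see myBatchOffset below).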
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_forward(const scalar_t* x, const int* label, const int* audLen, const int* txtLen,
                                        const int64_t* batchOffset,
                                        int64_t dictSize,  // 64-bit indexing for data tensor
                                        int64_t blankIdx, int64_t maxFLen, int64_t maxGLen, bool packedInput,
                                        acc_t* alpha, acc_t* beta, scalar_t* loss) {
  const int batch = blockIdx.y;
  const int tid = threadIdx.x;
  const auto myFLen = audLen[batch];
  // Note: the +1 accounts for the start-of-sentence token
  const auto myGLen = txtLen[batch] + 1;
  const auto myLabel = label + batch * (maxGLen - 1);
  const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen;
  const int64_t myStrideT = packedInput ? myGLen : maxGLen;
  const scalar_t* myX = x + myBatchOffset * dictSize;
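  // With these offsets, the log-probability of token k at cell (t, u) of this batch is
  // myX[(t * myStrideT + u) * dictSize + k].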
  int u = tid;

  if (blockIdx.x == 0) {
    // alpha path
    acc_t* myAlpha = alpha + batch * maxFLen * maxGLen;
    if (u == 0) myAlpha[0] = 0;
    __syncthreads();

    for (int64_t step = 1; step < myFLen + myGLen - 1; ++step) {
      // Move along the diagonal wavefront (all cells with t + u == step are independent) to
      // leverage the available parallelism
      for (u = tid; u < myGLen; u += blockDim.x) {
        int64_t t = step - u;
        if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
          // Eq(16) in [1]
          if (u == 0) {
            // alpha(t, u) = alpha(t-1, u) * null(t-1, u)
            myAlpha[t * maxGLen + u] = myAlpha[(t - 1) * maxGLen] + myX[((t - 1) * myStrideT) * dictSize + blankIdx];
          } else if (t == 0) {
            // alpha(t, u) = alpha(t, u-1) * y(t, u-1)
            myAlpha[u] = myAlpha[u - 1] + myX[(u - 1) * dictSize + myLabel[u - 1]];
          } else {
            // alpha(t, u) = alpha(t-1, u) * null(t-1, u) + alpha(t, u-1) * y(t, u-1)
            acc_t current = myAlpha[(t - 1) * maxGLen + u] + myX[((t - 1) * myStrideT + u) * dictSize + blankIdx];
            acc_t next = myAlpha[t * maxGLen + u - 1] + myX[(t * myStrideT + u - 1) * dictSize + myLabel[u - 1]];
            myAlpha[t * maxGLen + u] = logSumExp(next, current);
          }
        }
      }
      __syncthreads();
    }
  } else if (blockIdx.x == 1) {
    // beta path
    acc_t* myBeta = beta + batch * maxFLen * maxGLen;
    if (u == 0) {
      myBeta[(myFLen - 1) * maxGLen + myGLen - 1] = myX[((myFLen - 1) * myStrideT + myGLen - 1) * dictSize + blankIdx];
    }
    __syncthreads();

    for (int64_t step = myFLen + myGLen - 3; step >= 0; --step) {
      for (u = tid; u < myGLen; u += blockDim.x) {
        int64_t t = step - u;
        if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
          // Eq(18) in [1]
          if (u == myGLen - 1) {
            // beta(t, u) = beta(t+1, u) * null(t, u)
            myBeta[t * maxGLen + u] = myBeta[(t + 1) * maxGLen + u] + myX[(t * myStrideT + u) * dictSize + blankIdx];
          } else if (t == myFLen - 1) {
            // beta(t, u) = beta(t, u+1) * y(t, u)
            myBeta[t * maxGLen + u] = myBeta[t * maxGLen + u + 1] + myX[(t * myStrideT + u) * dictSize + myLabel[u]];
          } else {
            // beta(t, u) = beta(t+1, u)*null(t, u) + beta(t, u+1)*y(t, u)
            acc_t current = myBeta[(t + 1) * maxGLen + u] + myX[(t * myStrideT + u) * dictSize + blankIdx];
            acc_t next = myBeta[t * maxGLen + u + 1] + myX[(t * myStrideT + u) * dictSize + myLabel[u]];
            myBeta[t * maxGLen + u] = logSumExp(next, current);
          }
        }
      }
      __syncthreads();
    }
    if (tid == 0) loss[batch] = -myBeta[0];
  }
}

// Transducer loss function (i.e., the forward-backward algorithm) with batch-loading optimization.
// Compared to the vanilla version, there are two optimizations:
// 1. Load x in batches through loop unrolling to hide the load latency.
// 2. Use registers and shared memory to hold the alpha and beta values passed from one step to
//    the next.
// For simplicity, this kernel currently only supports U <= maxThread, which should be the common
// case. For cases where U > maxThread, the vanilla kernel is used as a fallback option.

// Details of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The forward (alpha) and backward (beta) passes are launched together. The input is assumed to
// have been converted to log scale by the preceding log_softmax layer.
// The diagonal-wavefront advance commonly used in dynamic programming is leveraged here.
// alpha and beta are of acc_t type, as they are essentially accumulators.

// This loss function supports packed input, where a tensor of shape [B, T, U, H] is packed into
// [B_packed, H].
// The don't-care region (t > audLen or u > txtLen) is removed.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
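// Data flow of the batch-load kernel (per thread, illustrated for the alpha path): each thread
// owns one u. The value alpha(t-1, u) it produced on the previous wavefront step stays in the
// register prvStepAlpha, while alpha(t, u-1) produced by the neighboring thread is read from
// shared memory. Two SMEM buffers alternate between consecutive steps (double buffering) so that
// reads of step s-1 never race with writes of step s. The batchLdSize-unrolled loop issues the
// global-memory loads for several wavefront steps up front, before the dependent compute loop
// consumes them.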
template <typename scalar_t, typename acc_t, int batchLdSize>
__global__ void transducer_loss_batch_load_forward(const scalar_t* x, const int* label, const int* audLen,
                                                   const int* txtLen, const int64_t* batchOffset, int64_t dictSize,
                                                   int64_t blankIdx, int64_t maxFLen, int64_t maxGLen, bool packedInput,
                                                   acc_t* alpha, acc_t* beta, scalar_t* loss) {
  const int batch = blockIdx.y;
  int u = threadIdx.x;
  const auto myFLen = audLen[batch];
  const auto myGLen = txtLen[batch] + 1;
  const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen;
  const int64_t myStrideT = packedInput ? myGLen : maxGLen;
  const scalar_t* myX = x + myBatchOffset * dictSize;
  scalar_t next[batchLdSize], current[batchLdSize];
  extern __shared__ char smem8[];
  auto smem = reinterpret_cast<acc_t*>(smem8);

  if (blockIdx.x == 0) {
    // alpha path
    acc_t* myAlpha = alpha + batch * maxFLen * maxGLen;
    // two SMEM regions are used for double buffering: reads and writes go to different buffers
    // to avoid data races
    acc_t* const sharedAlpha[2] = {smem, smem + maxGLen};
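    // Writes for wavefront step s go to sharedAlpha[s % 2]; reads of the previous step come from
    // sharedAlpha[(s - 1) % 2]. The initialization below seeds the buffer read by step 1.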

    sharedAlpha[0][u] = 0;
    __syncthreads();

    if (u == 0) myAlpha[0] = 0;

    auto myAlphaLabel = (u == 0) ? 0 : label[batch * (maxGLen - 1) + u - 1];
    // register used to pass value to the next step for the same thread
    acc_t prvStepAlpha = 0;
    for (int64_t step = 1; step < myFLen + myGLen - 1 + batchLdSize; step += batchLdSize) {
// Move along the diagonal wavefront to leverage available parallelism
// Batch loading X through loop unrolling
#pragma unroll
      for (int i = 0; i < batchLdSize; ++i) {
        if (step + i < myFLen + myGLen - 1) {
          // index computing
          int64_t t = step + i - u;
          int64_t currentId = ((t - 1) * myStrideT + u) * dictSize + blankIdx;
          int64_t nextId = (t * myStrideT + u - 1) * dictSize + myAlphaLabel;
          // main loading loop
          if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
            if (u == 0) {
              current[i] = myX[currentId];
            } else if (t == 0) {
              next[i] = myX[nextId];
            } else {
              current[i] = myX[currentId];
              next[i] = myX[nextId];
            }
          }
        }
      }
      // main computing loop
      for (int i = 0; i < batchLdSize; ++i) {
        // swap the pointer for double buffering
        auto sharedAlphaRd = sharedAlpha[(step + i - 1) % 2];
        auto sharedAlphaWr = sharedAlpha[(step + i) % 2];
        if (step + i < myFLen + myGLen - 1) {
          int64_t t = step + i - u;
          if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
            // Eq(16) in [1]
            if (u == 0)
              prvStepAlpha = prvStepAlpha + current[i];
            else if (t == 0)
              prvStepAlpha = sharedAlphaRd[u - 1] + next[i];
            else
              prvStepAlpha = logSumExp(prvStepAlpha + current[i], sharedAlphaRd[u - 1] + next[i]);
            sharedAlphaWr[u] = prvStepAlpha;
            myAlpha[t * maxGLen + u] = prvStepAlpha;
          }
        }
        __syncthreads();
      }
    }
  } else if (blockIdx.x == 1) {
    // beta path
    acc_t* myBeta = beta + batch * maxFLen * maxGLen;
    // two SMEM regions are used for double buffering: reads and writes go to different buffers
    // to avoid data races
    acc_t* const sharedBeta[2] = {smem, smem + maxGLen};
    sharedBeta[0][u] = myX[((myFLen - 1) * myStrideT + myGLen - 1) * dictSize + blankIdx];
    __syncthreads();

    auto myBetaLabel = (u == maxGLen - 1) ? 0 : label[batch * (maxGLen - 1) + u];
    // register used to pass value to the next step for the same thread
    acc_t prvStepBeta = myX[((myFLen - 1) * myStrideT + myGLen - 1) * dictSize + blankIdx];
    if (u == 0) myBeta[(myFLen - 1) * maxGLen + myGLen - 1] = prvStepBeta;

    for (int64_t step = 1; step < myFLen + myGLen - 1; step += batchLdSize) {
// Move along the diagonal wavefront to leverage available parallelism
// Batch loading X
#pragma unroll
      for (int i = 0; i < batchLdSize; ++i) {
        if (step + i < myFLen + myGLen - 1) {
          // index computing
          int64_t t = myFLen + myGLen - (step + i) - 2 - u;
          int64_t currentId = (t * myStrideT + u) * dictSize + blankIdx;
          int64_t nextId = (t * myStrideT + u) * dictSize + myBetaLabel;
          // main loading loop
          if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
            if (u == myGLen - 1) {
              current[i] = myX[currentId];
            } else if (t == myFLen - 1) {
              next[i] = myX[nextId];
            } else {
              current[i] = myX[currentId];
              next[i] = myX[nextId];
            }
          }
        }
      }
      // main computing loop
      for (int i = 0; i < batchLdSize; ++i) {
        // swap the pointer for double buffering
        auto sharedBetaRd = sharedBeta[(step + i - 1) % 2];
        auto sharedBetaWr = sharedBeta[(step + i) % 2];
        if (step + i < myFLen + myGLen - 1) {
          int64_t t = myFLen + myGLen - (step + i) - 2 - u;
          if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
            // Eq(18) in [1]
            if (u == myGLen - 1)
              prvStepBeta = prvStepBeta + current[i];
            else if (t == myFLen - 1)
              prvStepBeta = sharedBetaRd[u + 1] + next[i];
            else
              prvStepBeta = logSumExp(prvStepBeta + current[i], sharedBetaRd[u + 1] + next[i]);
            sharedBetaWr[u] = prvStepBeta;
            myBeta[t * maxGLen + u] = prvStepBeta;
          }
        }
        __syncthreads();
      }
    }
    if (u == 0) loss[batch] = -prvStepBeta;
  }
}

// Vanilla transducer loss backward operation.
// Details of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere,
// hence only Eq(20) in [1] is implemented in this kernel.

// Each thread block works on [batch, t, :, :] of the data. Each thread works on a specific u at a time.
// Since only gradients for the correct token and null token need to be updated, gradients at other
// locations are initialized to 0.

// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
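// Concretely, since x holds log-probabilities, the gradient written for the label token is
//   dL/dx(t, u, y_u)   = -lossGrad * exp(alpha(t, u) + x(t, u, y_u)   + beta(t, u+1) - beta(0, 0))
// and the gradient written for the blank token is
//   dL/dx(t, u, blank) = -lossGrad * exp(alpha(t, u) + x(t, u, blank) + beta(t+1, u) - beta(0, 0))
// with the beta term dropped at the terminal cell (t == myFLen-1, u == myGLen-1). This is the
// loss gradient w.r.t. the log-probabilities, cf. Eq(20) in [1].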
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_backward(const scalar_t* x, const scalar_t* lossGrad, const int* audLen,
                                         const int* txtLen, const int* label, const acc_t* alpha, const acc_t* beta,
                                         const int64_t* batchOffset, int64_t dictSize, int64_t blankIdx,
                                         int64_t maxFLen, int64_t maxGLen, bool packedInput, scalar_t* xGrad) {
  const int tid = threadIdx.x;
  const int t = blockIdx.x;
  const int batch = blockIdx.y;
  const int64_t myFLen = audLen[batch];
  const int64_t myGLen = txtLen[batch] + 1;
  const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen;
  const int64_t myStrideT = packedInput ? myGLen : maxGLen;
  auto myX = x + (myBatchOffset + t * myStrideT) * dictSize;
  auto myAlpha = alpha + batch * maxFLen * maxGLen;
  auto myBeta = beta + batch * maxFLen * maxGLen;
  auto myXGrad = xGrad + (myBatchOffset + t * myStrideT) * dictSize;
  auto myLabel = label + batch * (maxGLen - 1);

  int64_t u = tid;
  while (t < myFLen and u < myGLen) {
    // Do the update
    // loss = -ln(Pr(y*|x))
    acc_t grad = std::log(lossGrad[batch]) + myAlpha[t * maxGLen + u] - myBeta[0];
    if (u != myGLen - 1)
      myXGrad[u * dictSize + myLabel[u]] =
          -std::exp(grad + myBeta[t * maxGLen + u + 1] + myX[u * dictSize + myLabel[u]]);
    if (t == myFLen - 1 and u == myGLen - 1)
      myXGrad[u * dictSize + blankIdx] = -std::exp(grad + myX[u * dictSize + blankIdx]);
    else if (t != myFLen - 1)
      myXGrad[u * dictSize + blankIdx] = -std::exp(grad + myBeta[(t + 1) * maxGLen + u] + myX[u * dictSize + blankIdx]);

    u += blockDim.x;
  }
}

// Fused transducer loss backward operation.
// Details of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The bwd op of the preceding softmax layer is fused in this kernel.
// Each thread block works on [batch, t, u, :] of the data. Each thread works on a specific h at a time.

// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
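// Concretely, with z denoting the pre-log_softmax logits, fusing the softmax backward pass gives
//   dL/dz(t, u, h) = lossGrad * exp(alpha(t, u) + x(t, u, h) + beta(t, u) - beta(0, 0)) + dL/dx(t, u, h)
// where dL/dx(t, u, h) is the Eq(20) gradient that the non-fused kernel writes (nonzero only for
// the label and blank tokens), and the first term comes from the -softmax(h) * sum_k dL/dx(t, u, k)
// part of the log_softmax Jacobian, using beta(t, u) = y(t, u) beta(t, u+1) + null(t, u) beta(t+1, u).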
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_fused_backward(const scalar_t* x, const scalar_t* lossGrad, const int* audLen,
                                               const int* txtLen, const int* label, const acc_t* alpha,
                                               const acc_t* beta, const int64_t* batchOffset, int64_t dictSize,
                                               int64_t blankIdx, int64_t maxFLen, int64_t maxGLen, bool packedInput,
                                               scalar_t* xGrad) {
  const int tid = threadIdx.x;
  const int u = blockIdx.x;
  const int t = blockIdx.y;
  const int batch = blockIdx.z;
  const int64_t myFLen = audLen[batch];
  const int64_t myGLen = txtLen[batch] + 1;
  const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen;
  const int64_t myStrideT = packedInput ? myGLen : maxGLen;

  __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
  auto myXGrad = xGrad + (myBatchOffset + t * myStrideT + u) * dictSize;

  if (t < myFLen and u < myGLen) {
    auto myX = x + (myBatchOffset + t * myStrideT + u) * dictSize;
    auto myAlpha = alpha + batch * maxFLen * maxGLen;
    auto myBeta = beta + batch * maxFLen * maxGLen;
    auto myLabel = label + batch * (maxGLen - 1);

    // load and store shared variables in SMEM
    if (tid == 0) {
      commonFactor = std::log(lossGrad[batch]) + myAlpha[t * maxGLen + u] - myBeta[0];
      myBetaTU = myBeta[t * maxGLen + u];
      myBetaTUp1 = myBeta[t * maxGLen + u + 1];
      myBetaTp1U = myBeta[(t + 1) * maxGLen + u];
      myLabelShared = myLabel[u];
    }

    __syncthreads();

    for (int64_t h = tid; h < dictSize; h += blockDim.x) {
      // Do the update
      acc_t grad = commonFactor + myX[h];  // loss = -ln(Pr(y*|x))
      acc_t myGrad = std::exp(grad + myBetaTU);
      if (u != myGLen - 1 and h == myLabelShared) {
        myGrad -= std::exp(grad + myBetaTUp1);
      } else if (h == blankIdx) {
        if (t == myFLen - 1 and u == myGLen - 1)
          myGrad -= std::exp(grad);
        else if (t != myFLen - 1)
          myGrad -= std::exp(grad + myBetaTp1U);
      }
      myXGrad[h] = myGrad;
    }
  } else if (!packedInput) {
    // In non-packed mode, the gradients for the don't-care regions must be zeroed out here.
    for (int64_t h = tid; h < dictSize; h += blockDim.x) {
      myXGrad[h] = 0;
    }
  }
}

// Vectorized version of the fused transducer loss backward operation.
// Details of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The bwd op of the preceding softmax layer is fused in this kernel.
// Each thread block works on [batch, t, u, :] of the data. Each thread works on a vector of V
// consecutive h entries at a time.

// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
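// Each thread loads/stores V consecutive scalars through a single vec_t access (a 64-bit word as
// launched from the host code below), i.e. V = 4 for fp16 and V = 2 for fp32. The host side only
// selects this kernel when dictSize is divisible by V and the x / xGrad pointers satisfy the
// vec_t alignment requirement.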
template <typename scalar_t, typename acc_t, typename vec_t, int V>
__global__ void transducer_loss_fused_vec_backward(const scalar_t* x, const scalar_t* lossGrad, const int* audLen,
                                                   const int* txtLen, const int* label, const acc_t* alpha,
                                                   const acc_t* beta, const int64_t* batchOffset, int64_t dictSize,
                                                   int64_t blankIdx, int64_t maxFLen, int64_t maxGLen, bool packedInput,
                                                   scalar_t* xGrad) {
  const int tid = threadIdx.x;
  const int u = blockIdx.x;
  const int t = blockIdx.y;
  const int batch = blockIdx.z;
  const int64_t myFLen = audLen[batch];
  const int64_t myGLen = txtLen[batch] + 1;
  const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen;
  const int64_t myStrideT = packedInput ? myGLen : maxGLen;

  __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
  auto myXGrad = xGrad + (myBatchOffset + t * myStrideT + u) * dictSize;
  auto myX = x + (myBatchOffset + t * myStrideT + u) * dictSize;
  auto myAlpha = alpha + batch * maxFLen * maxGLen;
  auto myBeta = beta + batch * maxFLen * maxGLen;
  auto myLabel = label + batch * (maxGLen - 1);

  // Variables for vectorization
  scalar_t myXBuffer[V], myXGradBuffer[V];
  auto myXVec = reinterpret_cast<vec_t const*>(myX);
  auto myXGradVec = reinterpret_cast<vec_t*>(myXGrad);
  auto myXBufferVec = reinterpret_cast<vec_t*>(myXBuffer);
  auto myXGradBufferVec = reinterpret_cast<vec_t*>(myXGradBuffer);
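  // A single vec_t load fills all V entries of myXBuffer; results are staged in myXGradBuffer
  // and written back with a single vec_t store.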
  if (t < myFLen and u < myGLen) {
    // load and store shared variables in SMEM
    if (tid == 0) {
      commonFactor = std::log(lossGrad[batch]) + myAlpha[t * maxGLen + u] - myBeta[0];
      myBetaTU = myBeta[t * maxGLen + u];
      if (t != myFLen - 1) myBetaTp1U = myBeta[(t + 1) * maxGLen + u];
      if (u != myGLen - 1) {
        myBetaTUp1 = myBeta[t * maxGLen + u + 1];
        myLabelShared = myLabel[u];
      }
    }

    __syncthreads();

#pragma unroll
    for (int64_t h0 = tid * V; h0 < dictSize; h0 += blockDim.x * V) {
      // Load myX in a vector form
      *myXBufferVec = myXVec[h0 / V];
// Do the update for a vector of input
#pragma unroll
      for (int i = 0; i < V; ++i) {
        auto h = h0 + i;
        acc_t grad = commonFactor + myXBuffer[i];  // loss = -ln(Pr(y*|x))
        acc_t myGrad = std::exp(grad + myBetaTU);
        if (u != myGLen - 1 and h == myLabelShared) {
          myGrad -= std::exp(grad + myBetaTUp1);
        } else if (h == blankIdx) {
          if (t == myFLen - 1 and u == myGLen - 1)
            myGrad -= std::exp(grad);
          else if (t != myFLen - 1)
            myGrad -= std::exp(grad + myBetaTp1U);
        }
        myXGradBuffer[i] = myGrad;
      }

      // Store myXGrad in a vector form
      myXGradVec[h0 / V] = *myXGradBufferVec;
    }
  } else if (!packedInput) {
    // In non-packed mode, the gradients for the don't-care regions must be zeroed out here.
    for (int64_t h0 = tid * V; h0 < dictSize; h0 += blockDim.x * V) {
      myXGradVec[h0 / V] = 0;
    }
  }
}

std::vector<torch::Tensor> transducer_loss_cuda_forward(torch::Tensor x, torch::Tensor label, torch::Tensor audLen,
                                                        torch::Tensor txtLen, torch::Tensor batchOffset, int maxFLen,
                                                        int blankIdx, int opt, bool packedInput) {
  auto scalarType = x.scalar_type();
  auto tensorOpt = x.options();
  const int batchSize = label.size(0);
  const int maxGLen = label.size(1) + 1;
  const int dictSize = x.size(-1);

  TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize, "Expected blank index to be in the range of 0 to ", dictSize - 1,
              ", but got ", blankIdx);
  TORCH_CHECK(opt == -1 or opt == 0 or opt == 1, "Got an invalid optimization level ", opt);

  // The data type of alpha and beta is resolved at dispatch time, hence they are declared here
  // and assigned inside the dispatch block.
  torch::Tensor alpha;
  torch::Tensor beta;
  torch::Tensor loss = torch::empty({batchSize}, tensorOpt);
  const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
  const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;
  const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock;
  const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr<int64_t>() : nullptr;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      scalarType, "transducer_loss_cuda_forward", ([&] {
        // resolve accumulation type
        using acc_t = at::acc_type<scalar_t, true>;
        auto accType = c10::CppTypeToScalarType<acc_t>::value;
        auto accTensorOpt = tensorOpt.dtype(accType);
        alpha = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);
        beta = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);

        // Decide which kernel to launch based on the problem size.
        // If the required SMEM size or the number of threads exceeds the limit, fall back to the
        // vanilla kernel.
        const auto smemSize = 2 * maxGLen * sizeof(acc_t);
        const auto optFallBack = (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0
                                 : (opt == -1)                                               ? 1
                                                                                             : opt;
        const int threads = std::min(maxThreadPerBlock, maxGLen);
        const dim3 blocks(2, batchSize, 1);
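        // Both kernels use the same grid: blockIdx.x == 0 computes alpha, blockIdx.x == 1
        // computes beta, and blockIdx.y indexes the batch.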

        if (optFallBack == 0)
          transducer_loss_forward<<<blocks, threads, 0, stream>>>(
              x.data_ptr<scalar_t>(), label.data_ptr<int>(), audLen.data_ptr<int>(), txtLen.data_ptr<int>(),
              batchOffsetPtr, dictSize, blankIdx, maxFLen, maxGLen, packedInput, alpha.data_ptr<acc_t>(),
              beta.data_ptr<acc_t>(), loss.data_ptr<scalar_t>());
        else if (optFallBack == 1)
          transducer_loss_batch_load_forward<scalar_t, acc_t, 4><<<blocks, threads, smemSize, stream>>>(
              x.data_ptr<scalar_t>(), label.data_ptr<int>(), audLen.data_ptr<int>(), txtLen.data_ptr<int>(),
              batchOffsetPtr, dictSize, blankIdx, maxFLen, maxGLen, packedInput, alpha.data_ptr<acc_t>(),
              beta.data_ptr<acc_t>(), loss.data_ptr<scalar_t>());
      }));
  C10_CUDA_CHECK(cudaGetLastError());

  return {alpha, beta, loss};
}

torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha,
                                            torch::Tensor beta, torch::Tensor audLen, torch::Tensor txtLen,
                                            torch::Tensor label, torch::Tensor batchOffset, int maxFLen, int blankIdx,
                                            int opt, bool fuseSoftmaxBackward, bool packedInput) {
  auto dtype = x.scalar_type();
  torch::Tensor xGrad;
  const int batchSize = label.size(0);
  const int maxGLen = label.size(1) + 1;
  const int dictSize = x.size(-1);
  const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
  const int maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;
  const int warpSize = deviceProperties->warpSize;
  const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr<int64_t>() : nullptr;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (fuseSoftmaxBackward) {
    // Allocate an uninitialized tensor for performance; the kernel therefore needs to write
    // zeros to the don't-care regions itself.
    xGrad = torch::empty_like(x);

    // Would like each thread to work on 4 hidden units
    const int workPerThread = 4;
    // Don't want to have more than 128 threads per thread block
    const int maxThreadPerElmt = std::min(128, maxThreadPerBlock);
    const int threads = std::min(maxThreadPerElmt, std::max(warpSize, (dictSize + workPerThread - 1) / workPerThread));
    const dim3 blocks(maxGLen, maxFLen, batchSize);
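    // Grid layout: one thread block per (u, t, batch) cell; threads stride over the dictSize
    // (hidden) dimension, with the block size chosen so that each thread covers roughly
    // workPerThread entries.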

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        dtype, "transducer_loss_cuda_backward", ([&] {
          using vec_t = uint64_t;
          using acc_t = at::acc_type<scalar_t, true>;
          constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);
          constexpr int vecAlignment = std::alignment_of<vec_t>::value;
          // check whether the input and output tensors meet the alignment requirement
          bool memAlign = reinterpret_cast<uint64_t>(x.data_ptr<scalar_t>()) % vecAlignment == 0 and
                          reinterpret_cast<uint64_t>(xGrad.data_ptr<scalar_t>()) % vecAlignment == 0;

          if (vectFactor > 1 and dictSize % vectFactor == 0 and memAlign) {
            transducer_loss_fused_vec_backward<scalar_t, acc_t, vec_t, vectFactor><<<blocks, threads, 0, stream>>>(
                x.data_ptr<scalar_t>(), lossGrad.data_ptr<scalar_t>(), audLen.data_ptr<int>(), txtLen.data_ptr<int>(),
                label.data_ptr<int>(), alpha.data_ptr<acc_t>(), beta.data_ptr<acc_t>(), batchOffsetPtr, dictSize,
                blankIdx, maxFLen, maxGLen, packedInput, xGrad.data_ptr<scalar_t>());
          } else {
            transducer_loss_fused_backward<<<blocks, threads, 0, stream>>>(
                x.data_ptr<scalar_t>(), lossGrad.data_ptr<scalar_t>(), audLen.data_ptr<int>(), txtLen.data_ptr<int>(),
                label.data_ptr<int>(), alpha.data_ptr<acc_t>(), beta.data_ptr<acc_t>(), batchOffsetPtr, dictSize,
                blankIdx, maxFLen, maxGLen, packedInput, xGrad.data_ptr<scalar_t>());
          }
        }));
  } else {
    // For the non-fused kernel, the gradients that need to be written are very sparse, hence
    // initialize the tensor with all zeros.
    xGrad = torch::zeros_like(x);
    // don't launch more threads than needed.
    const int threads = std::min(maxThreadPerBlock, maxGLen);
    const dim3 blocks(maxFLen, batchSize);
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] {
                                          using acc_t = at::acc_type<scalar_t, true>;
                                          transducer_loss_backward<<<blocks, threads, 0, stream>>>(
                                              x.data_ptr<scalar_t>(), lossGrad.data_ptr<scalar_t>(),
                                              audLen.data_ptr<int>(), txtLen.data_ptr<int>(), label.data_ptr<int>(),
                                              alpha.data_ptr<acc_t>(), beta.data_ptr<acc_t>(), batchOffsetPtr, dictSize,
                                              blankIdx, maxFLen, maxGLen, packedInput, xGrad.data_ptr<scalar_t>());
                                        }));
  }
  C10_CUDA_CHECK(cudaGetLastError());

  return xGrad;
}
